#!/usr/bin/env python3

import cv2
import numpy as np
import numpy_extra as npe
import rospy
import threading

from ipm import IPM

from cv_bridge import *
from transformations import euler_from_quaternion, quaternion_from_euler
from sensor_msgs.msg import *
from std_msgs.msg import *
from nav_msgs.msg import *
from geometry_msgs.msg import *

class GroundProjectNode(object):
    """ROS node that paints camera pixels onto a top-down local map.

    The node subscribes to a camera image, an IMU and a vehicle speed topic.
    The heading comes from the IMU orientation and the position is
    dead-reckoned from the speed in `spin`.  Every incoming image is cropped
    to a horizontal band and each remaining pixel is written into
    `local_map` at the ground position given by the IPM ground-plane lookup
    (`ground_plane_r`, `ground_plane_theta`).

    `local_map` is a fixed-size raster whose centre pixel corresponds to map
    coordinate (0, 0); `cur_x` / `cur_y` are the vehicle position in metres
    relative to that centre.  `spin` periodically shifts the raster so the
    vehicle stays inside it.
    """

    # Frequency of the main loop; the dead-reckoning step is derived from it
    # so the two can never disagree.
    SPIN_RATE_HZ = 50
    # The map is re-centred once every this many loop iterations.
    RECENTER_EVERY_N_TICKS = 100

    def __init__(self,
      ns_camera = "/camera_front_t",
      ns_imu = "/imu",
      ns_vehicle = "/vehicle",
      pixels_per_meter = 8.0,
      configuration = "camera_ethernet_imx291"
    ):
        """Initialise the ROS node, its state and finally its subscribers.

        Args:
            ns_camera: namespace of the camera; "<ns_camera>/image_raw" is
                subscribed.
            ns_imu: namespace of the IMU; "<ns_imu>/data" is subscribed.
            ns_vehicle: namespace of the vehicle; "<ns_vehicle>/speed" is
                subscribed.
            pixels_per_meter: resolution of the local map.
            configuration: camera configuration name forwarded to `IPM`.
        """
        rospy.init_node("ground_project_node")

        self.pixels_per_meter = pixels_per_meter

        self.is_initialized = False

        # Ground-plane lookup tables, filled lazily on the first image
        # because they depend on the image shape.
        self.ground_plane_r, self.ground_plane_theta = None, None

        self.local_map = np.zeros((1024, 1024, 3), dtype=np.uint8)

        self.map_center_x = 0.0
        self.map_center_y = 0.0

        # current position of the car in the map in meters
        # (0, 0) is the center of the image
        self.cur_y = 0.0
        self.cur_x = 0.0
        self.cur_yaw = 0.0
        self.velocity_x = 0.0

        self.ipm = IPM(pixels_per_meter = self.pixels_per_meter, configuration = configuration)
        self.lock = threading.Lock()
        self.is_initialized = True

        # Subscribe last: callbacks may fire as soon as a subscriber exists,
        # and they need the lock, the IPM and the map to be in place already.
        self.sub_image_raw = rospy.Subscriber(
            "%s/image_raw" % ns_camera, Image, self.on_image_raw)
        self.sub_imu_data = rospy.Subscriber(
            "%s/data" % ns_imu, Imu, self.on_imu_data, queue_size = 1)
        self.sub_speed = rospy.Subscriber(
            "%s/speed" % ns_vehicle, Float32, self.on_speed, queue_size = 1)

    def on_speed(self, msg):
        """Store the latest vehicle speed, converted from km/h to m/s."""
        self.velocity_x = msg.data / 3.6 # km/h to m/s

    def on_imu_data(self, msg):
        """Update the heading (`cur_yaw`) from the IMU orientation.

        The quaternion is passed as (w, x, y, z); roll and pitch are ignored.
        NOTE(review): confirm that the imported `euler_from_quaternion`
        expects scalar-first order.
        """
        q = msg.orientation
        _roll, _pitch, self.cur_yaw = euler_from_quaternion((q.w, q.x, q.y, q.z))

    def on_image_raw(self, msg):
        """Project the incoming camera image onto the local map and show it."""
        # `with` guarantees the lock is released even if conversion or
        # projection raises; otherwise one bad frame would deadlock the node.
        with self.lock:
            img = imgmsg_to_cv2(msg)

            if self.ground_plane_r is None or self.ground_plane_theta is None:
                self.ground_plane_r, self.ground_plane_theta = \
                    self.ipm.get_ground_plane(img.shape)

            # Only the band between 40% and 80% of the image height is used.
            cutoff_top = int(img.shape[0] * 0.4)
            cutoff_bottom = int(img.shape[0] * 0.8)

            self._project_onto_map(
                img[cutoff_top:cutoff_bottom, :],
                self.ground_plane_r[cutoff_top:cutoff_bottom, :],
                self.ground_plane_theta[cutoff_top:cutoff_bottom, :])

        # The map stores half-intensity values, so double them for display;
        # rows are flipped so that +y points up on screen.
        cv2.imshow('local_map', self.local_map[::-1,:] * 2)
        cv2.waitKey(1)

    def _project_onto_map(self, img, ground_plane_r, ground_plane_theta):
        """Write `img` pixels into `local_map` at their ground positions.

        Args:
            img: image band; its first two dimensions must match those of
                the two ground-plane arrays.
            ground_plane_r: per-pixel ground range; entries <= 0 are skipped.
            ground_plane_theta: per-pixel bearing, added to `cur_yaw`.
                (presumably metres / radians from the IPM - verify there.)
        """
        map_h, map_w = self.local_map.shape[:2]

        bearing = ground_plane_theta + self.cur_yaw
        values_y = self.cur_y + ground_plane_r * np.sin(bearing)
        values_x = self.cur_x + ground_plane_r * np.cos(bearing)

        # Keep the pixel coordinates as floats until after the bounds check.
        # Casting first (the map used int16) lets far or non-finite ranges
        # wrap around into the valid window and paint wrong cells.
        pos_py = values_y * self.pixels_per_meter + map_h / 2.0
        pos_px = values_x * self.pixels_per_meter + map_w / 2.0

        # "> -1" (not ">= 0") because the integer cast below truncates
        # towards zero, so (-1, 0) lands on index 0.  NaN fails every test.
        where_valid = (ground_plane_r > 0) & \
            (pos_py > -1) & (pos_px > -1) & \
            (pos_py < map_h) & (pos_px < map_w)

        rows = pos_py[where_valid].astype(np.intp)
        cols = pos_px[where_valid].astype(np.intp)

        # Stored at half intensity; on_image_raw doubles it again for display.
        self.local_map[rows, cols, :] = img[where_valid] / 2

    def shift_map(self, shift_x, shift_y):
        """Shift the local map and the vehicle position by the same amount.

        Args:
            shift_x, shift_y: requested shift in metres.  The raster can only
                move by whole pixels, so the shift actually applied is the
                request truncated to a whole number of pixels.
        """
        with self.lock:
            shift_px = int(shift_x * self.pixels_per_meter)
            shift_py = int(shift_y * self.pixels_per_meter)
            if shift_px == 0 and shift_py == 0:
                # Less than one pixel: nothing to move.
                return

            # Move the pose by exactly what the raster moves.  Subtracting
            # the unquantised request would add a sub-pixel error per shift.
            self.cur_x -= shift_px / self.pixels_per_meter
            self.cur_y -= shift_py / self.pixels_per_meter
            self.local_map = npe.shift_2d_replace(self.local_map, -shift_px, -shift_py, 0)

    def _integrate_motion(self, dt):
        """Advance `cur_x` / `cur_y` by `dt` seconds along the heading."""
        step = self.velocity_x * dt
        self.cur_x += step * np.cos(self.cur_yaw)
        self.cur_y += step * np.sin(self.cur_yaw)

    def spin(self):
        """Run the main loop: dead-reckon the pose and re-centre the map.

        Loops at `SPIN_RATE_HZ` until ROS shuts down.  Every
        `RECENTER_EVERY_N_TICKS` iterations the map is shifted so that the
        vehicle sits 0.4 map extents from the centre, opposite its heading.
        """
        rate = rospy.Rate(self.SPIN_RATE_HZ)
        # Nominal step of one iteration; it must match the loop rate or the
        # travelled distance is scaled wrongly.
        dt = 1.0 / self.SPIN_RATE_HZ
        seq = 0
        while not rospy.is_shutdown():
            seq += 1
            rate.sleep()
            if not self.is_initialized:
                continue

            self._integrate_motion(dt)

            if seq % self.RECENTER_EVERY_N_TICKS != 0:
                continue

            map_h, map_w = self.local_map.shape[:2]
            target_x = -np.cos(self.cur_yaw) * map_w / self.pixels_per_meter * 0.4
            target_y = -np.sin(self.cur_yaw) * map_h / self.pixels_per_meter * 0.4
            shift_x = target_x - self.cur_x
            shift_y = target_y - self.cur_y

            # Shift when either axis is off target, not only when both are.
            if shift_y != 0.0 or shift_x != 0.0:
                self.shift_map(-shift_x, -shift_y)

if __name__ == "__main__":
    # Script entry point: build the node with its default topic namespaces
    # and run its main loop until ROS shuts down.
    GroundProjectNode().spin()

